{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5fa051c8-8f5e-4809-b90e-bf129a701352",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d549cf85-c638-4dec-a436-254da7060ee3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load a checkpoint already in the target format; its key names serve as the reference.\n",
    "checkpoints = torch.load('../output/model_Epoch_00030_Iter_0000029.pth', map_location='cpu')['model']\n",
    "new_keys = list(checkpoints.keys())\n",
    "new_keys[:10]"
   ]
  },
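  {
   "cell_type": "markdown",
   "id": "1e7aa3b2-5c44-4f0e-9d1a-0b2c3d4e5f60",
   "metadata": {},
   "source": [
    "Optional sanity check, a minimal sketch assuming the `checkpoints` reference state dict loaded above: tally keys by top-level prefix to see the naming scheme the conversion below must produce."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e7aa3b2-5c44-4f0e-9d1a-0b2c3d4e5f61",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Survey top-level key prefixes in the reference state dict (sketch).\n",
    "from collections import Counter\n",
    "\n",
    "Counter(k.split('.')[0] for k in checkpoints).most_common()"
   ]
  },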
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "81fd1bd8-d8ed-457f-a9ec-d2b0bf91c8c7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/mnt/lustre/zhujinguo/jinguo_data/codes/Uni-Perceiver/work_dirs'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "# Confirm the working directory; the checkpoint paths below are relative to it.\n",
    "os.getcwd()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d9a698e1-de54-4257-a45f-65cd2d7cf095",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the DeepSpeed MoE checkpoint to convert; map_location='cpu' keeps it off the GPU.\n",
    "ds_checkpoints = torch.load('deepspeed_moe/BERT_L12_H192_experiments/7task_150k_berttiny_lr1e-3_wd0.05_gc0.1_prenorm_warm10k_layerscale1e-3_uniformdp0.1_maeinit_fixedpos_torchfp16_unifieddataset_224inputsize_tagmoe_alllayer/tagmoe_alllayer_exp4/149999/mp_rank_00_model_states.pt', map_location='cpu')['module']\n",
    "\n",
    "# Alternative checkpoints converted with the same recipe:\n",
    "# ds_checkpoints = torch.load('/nfs/zhujinguo/codes/xmodaler/work_dirs/deepspeed_moe/BERT_L12_H768_experiments/basetagmoe_pretrainstage2/89999/mp_rank_00_model_states.pt', map_location='cpu')['module']\n",
    "# ds_checkpoints = torch.load('/nfs/zhujinguo/codes/xmodaler/work_dirs/deepspeed_moe/BERT_L12_H768_experiments/bertbase_womoe/89999/mp_rank_00_model_states.pt', map_location='cpu')['module']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d5f02e7e-95eb-4386-ac68-f8be0f7ac3c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Key names on the DeepSpeed side; the mapping below translates them to the reference scheme.\n",
    "oldkeys = list(ds_checkpoints.keys())\n",
    "# oldkeys[:20]"
   ]
  },
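  {
   "cell_type": "markdown",
   "id": "2f8bb4c3-6d55-4a1f-8e2b-1c3d4e5f6a70",
   "metadata": {},
   "source": [
    "A quick diff of the two naming schemes, a sketch assuming both load cells above were run: prefixes present on only one side are the ones `mapping_dict` below has to bridge."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f8bb4c3-6d55-4a1f-8e2b-1c3d4e5f6a71",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare top-level prefixes of the DeepSpeed keys against the reference keys (sketch).\n",
    "old_prefixes = {k.split('.')[0] for k in oldkeys}\n",
    "new_prefixes = {k.split('.')[0] for k in new_keys}\n",
    "sorted(old_prefixes - new_prefixes), sorted(new_prefixes - old_prefixes)"
   ]
  },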
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6a9c4928-e808-4e5d-a8cb-c4d133cc9c6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Substring renames from DeepSpeed MoE key names to the fused-encoder scheme.\n",
    "# Applied in insertion order; every matching substring in a key is replaced.\n",
    "mapping_dict = {\n",
    "    'encoder.': 'fused_encoder.',\n",
    "    # 'attention.self.qkv_proj.weight': 'self_attn.in_proj_weight',\n",
    "    # 'attention.self.qkv_proj.bias': 'self_attn.in_proj_bias',\n",
    "    'attention.self.qkv_proj': 'self_attn.qkv_proj',\n",
    "    'deepspeed_moe.gate': 'gate',\n",
    "    'deepspeed_moe.experts': 'experts',\n",
    "    'attention.output.dense': 'self_attn.dense',\n",
    "    'attention_output.residual_scale': 'gamma_1',\n",
    "    'ffn.dense.': 'linear1.',\n",
    "    'ffn.dense2.': 'linear2.',\n",
    "    'ffn_output.residual_scale': 'gamma_2',\n",
    "    'LayerNormModules.0.': 'norm1.',\n",
    "    'LayerNormModules.1.': 'norm2.',\n",
    "    'predictor.': 'loss_prepare.',\n",
    "}"
   ]
  },
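  {
   "cell_type": "markdown",
   "id": "3a9cc5d4-7e66-4b2a-9f3c-2d4e5f6a7b80",
   "metadata": {},
   "source": [
    "The conversion loop below applies these renames as substring replacements, in dict insertion order. A small helper, shown here as a sketch, makes the rule testable on one key; the example key is hypothetical, for illustration only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a9cc5d4-7e66-4b2a-9f3c-2d4e5f6a7b81",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch of the rename rule used by the conversion loop below.\n",
    "def rename_key(k, mapping=mapping_dict):\n",
    "    for origin_str, target_str in mapping.items():\n",
    "        if origin_str in k:\n",
    "            k = k.replace(origin_str, target_str)\n",
    "    return k\n",
    "\n",
    "# Hypothetical example key:\n",
    "rename_key('encoder.layer.0.attention.self.qkv_proj.weight')"
   ]
  },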
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9a84319a-d13c-411a-bd21-c0a9c1adb872",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert: drop unused weights, rename keys, fix shapes, and cast to fp32.\n",
    "new_checkpoint = {}\n",
    "for k, v in ds_checkpoints.items():\n",
    "    # Layer-scale parameters carry two leading singleton dims; flatten them to 1-D.\n",
    "    if k.endswith('residual_scale'):\n",
    "        v = v.squeeze(1).squeeze(0)\n",
    "\n",
    "    # The visual embedding is not carried over.\n",
    "    if k.startswith('visual_embed'):\n",
    "        continue\n",
    "\n",
    "    # Apply every matching substring rename, in mapping_dict insertion order.\n",
    "    for origin_str, target_str in mapping_dict.items():\n",
    "        if origin_str in k:\n",
    "            k = k.replace(origin_str, target_str)\n",
    "\n",
    "    # Optionally merge the type embedding into video_embed:\n",
    "    # if k == 'video_embed.embeddings.bias':\n",
    "    #     v = v + ds_checkpoints['video_embed.embeddings_type.weight'][0]\n",
    "\n",
    "    # Cast to fp32 (the run name indicates fp16 training).\n",
    "    new_checkpoint[k] = v.float()\n",
    "    # if 'wg' in k:\n",
    "    #     print(f'{k}, {v}')\n",
    "# new_checkpoint['video_embed.embeddings.bias'] = new_checkpoint['video_embed.embeddings.bias'] + new_checkpoint['video_embed.embeddings_type.weight'][0]"
   ]
  },
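  {
   "cell_type": "markdown",
   "id": "4b0dd6e5-8f77-4c3b-8a4d-3e5f6a7b8c90",
   "metadata": {},
   "source": [
    "Sanity check, a sketch assuming `new_keys` from the reference checkpoint is available: after conversion the key set should match the target model's, so both difference lists should ideally be empty."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b0dd6e5-8f77-4c3b-8a4d-3e5f6a7b8c91",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare converted keys against the reference checkpoint's keys (sketch).\n",
    "converted = set(new_checkpoint)\n",
    "reference = set(new_keys)\n",
    "missing = sorted(reference - converted)\n",
    "unexpected = sorted(converted - reference)\n",
    "len(missing), missing[:10], len(unexpected), unexpected[:10]"
   ]
  },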
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b5c999e3-e4b1-4949-b89e-4ee2259db8fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save in the format the loader above expects: a dict with a single 'model' entry.\n",
    "torch.save({'model': new_checkpoint}, 'pretrained_models/uni-perceiver-moe-tiny-L12-H192-224size-pretrained-withvtype.pth')\n",
    "\n",
    "# Output names used for the alternative checkpoints above:\n",
    "# torch.save({'model': new_checkpoint}, 'pretrained_models/uni-perceiver-moe-tiny-L12-H192-224size-pretrained.pth')\n",
    "# torch.save({'model': new_checkpoint}, 'pretrained_models/uni-perceiver-moe-base-L12-H768-224size-pretrained.pth')\n",
    "# torch.save({'model': new_checkpoint}, 'pretrained_models/uni-perceiver-base-L12-H768-224size-pretrained-custom-attn-module.pth')"
   ]
  },
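  {
   "cell_type": "markdown",
   "id": "5c1ee7f6-9a88-4d4c-9b5e-4f6a7b8c9d92",
   "metadata": {},
   "source": [
    "Round-trip check, a sketch: reload the file just saved and confirm the key count and dtype survived serialization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c1ee7f6-9a88-4d4c-9b5e-4f6a7b8c9d93",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the converted checkpoint to verify it round-trips (sketch).\n",
    "reloaded = torch.load('pretrained_models/uni-perceiver-moe-tiny-L12-H192-224size-pretrained-withvtype.pth', map_location='cpu')['model']\n",
    "assert len(reloaded) == len(new_checkpoint)\n",
    "next(iter(reloaded.values())).dtype  # expect torch.float32"
   ]
  }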
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
